from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree, is_undirected, to_undirected
from torch_geometric.utils.convert import to_scipy_sparse_matrix
import os 
import scipy.sparse as ss
import numpy as np
import torch
import argparse


def main():
    """Prepare XR-Transformer input files from an OGB node-property dataset.

    Steps, all driven by command-line arguments:

    1. Load the dataset and make its ``edge_index`` undirected if needed.
    2. Compute node degrees and keep the nodes whose degree is strictly
       below ``--max_deg``.
    3. Save the adjacency matrix as the label matrix: the filtered rows as
       ``Y_trn.npz`` and the full matrix as ``Y_tst.npz``.
    4. Copy the lines of ``--raw_text_path`` that belong to kept nodes into
       ``Raw_text_filtered.txt`` (line ``i`` is treated as node ``i``).
    5. Save the TF-IDF matrix from ``--tfidf_path`` both in full
       (``tfidf_feature.npz``) and row-filtered
       (``tfidf_feature_filtered.npz``).

    All outputs go to ``<save_data_dir>/<dataset>``, which is created if it
    does not exist.

    Raises:
        ValueError: if the raw text file has fewer lines than required to
            cover every kept node.
    """
    parser = argparse.ArgumentParser(description='Prepare data for XRTransformer')
    parser.add_argument('--dataset', type=str, default="ogbn-arxiv")
    parser.add_argument('--data_root_dir', type=str, default="./dataset")
    parser.add_argument('--max_deg', type=int, default=1000)
    parser.add_argument('--save_data_dir', type=str, default="./data_for_XRTransformer")
    parser.add_argument('--raw_text_path', type=str, required=True, help="Path of raw text (.txt file, each raw correspond to a node)")
    parser.add_argument('--tfidf_path', type=str, required=True, help="Path of tfidf feature (.npz file)")
    args = parser.parse_args()
    print(args)

    # Change args.save_data_dir to args.save_data_dir/args.dataset and make
    # sure it exists; the save calls below do not create directories.
    args.save_data_dir = os.path.join(args.save_data_dir, args.dataset)
    os.makedirs(args.save_data_dir, exist_ok=True)

    dataset = PygNodePropPredDataset(name=args.dataset, root=args.data_root_dir)
    data = dataset[0]
    edge_index = data.edge_index
    num_nodes = data.x.shape[0]

    # Make sure edge_index is undirected.
    if not is_undirected(edge_index):
        edge_index = to_undirected(edge_index)

    # Compute node degrees. num_nodes is passed explicitly so that isolated
    # nodes with the highest indices still get a (zero) degree entry instead
    # of being dropped from the vector.
    node_degree = degree(edge_index[0], num_nodes=num_nodes)

    # Select node indices whose degree < args.max_deg (ascending order).
    Filtered_idx = torch.where(node_degree < args.max_deg)[0]
    print('Number of original nodes:{}'.format(num_nodes))
    print('Number of filtered nodes:{}'.format(len(Filtered_idx)))

    # NumPy copy of the kept indices, used to index the SciPy matrices.
    filtered_rows = Filtered_idx.cpu().numpy()

    # Construct and save label matrix (adjacency matrix) Y. The explicit
    # num_nodes keeps the matrix square over all nodes, matching the rows of
    # the text and TF-IDF inputs.
    Adj_csr = ss.csr_matrix(to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes))
    ss.save_npz(os.path.join(args.save_data_dir, 'Y_trn.npz'), Adj_csr[filtered_rows])
    ss.save_npz(os.path.join(args.save_data_dir, 'Y_tst.npz'), Adj_csr)
    print('XMC labels Y saved')

    # Apply the same filtering to the raw text and the TF-IDF features.
    # Raw text part. It is recommended to also put the raw text in args.save_data_dir.
    kept_lines = set(filtered_rows.tolist())
    filtered_text_path = os.path.join(args.save_data_dir, 'Raw_text_filtered.txt')
    written = 0
    with open(args.raw_text_path, 'r') as file_read, \
            open(filtered_text_path, 'w') as file_write:
        # Stream the file; membership in kept_lines decides whether a line is
        # copied, so lines after the last kept node are skipped safely.
        for line_idx, line in enumerate(file_read):
            if line_idx in kept_lines:
                file_write.write(line)
                written += 1

    # We should write exactly len(Filtered_idx) lines.
    if written != len(kept_lines):
        raise ValueError(
            'Raw text file {} yielded {} filtered lines, expected {}'.format(
                args.raw_text_path, written, len(kept_lines)))
    print('Raw text saved')

    # TF-IDF features part: save the full matrix and the filtered rows.
    TFIDF = ss.load_npz(args.tfidf_path)
    ss.save_npz(os.path.join(args.save_data_dir, 'tfidf_feature.npz'), TFIDF)
    ss.save_npz(os.path.join(args.save_data_dir, 'tfidf_feature_filtered.npz'), TFIDF[filtered_rows])
    print('tfidf features saved')

    print('Complete data preparation for XRTransformer!')

# Run the data preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
